package com.sqi.reactive.common.repository

import com.google.common.collect.Lists
import com.google.common.collect.Maps
import com.sqi.reactive.common.result.PageResult
import com.sqi.reactive.common.util.getPostfix
import com.sqi.reactive.common.util.getPrefix
import org.springframework.data.domain.Pageable
import org.springframework.data.domain.Sort
import org.springframework.data.r2dbc.convert.R2dbcConverter
import org.springframework.data.r2dbc.core.DatabaseClient
import org.springframework.data.r2dbc.core.R2dbcEntityTemplate
import org.springframework.data.r2dbc.core.ReactiveDataAccessStrategy
import org.springframework.data.r2dbc.repository.support.SimpleR2dbcRepository
import org.springframework.data.relational.core.query.Criteria
import org.springframework.data.relational.core.query.Criteria.empty
import org.springframework.data.relational.core.query.Query
import org.springframework.data.relational.repository.query.RelationalEntityInformation
import reactor.core.publisher.Flux
import reactor.core.publisher.Mono
import java.lang.reflect.Modifier
import java.time.Instant
import java.time.LocalDateTime
import java.time.ZoneId
import java.util.Collections
import java.util.Locale

/**
 * Repository implementation that adds map-driven dynamic queries and paging on top of
 * [SimpleR2dbcRepository].
 *
 * Filter keys are only honoured when they name a field declared on the entity class or one of
 * its superclasses (static and synthetic fields excluded); any other key is ignored, so callers
 * can only filter on entity fields.
 *
 * @author sjl
 * @date 2020/5/29
 */
class DynamicQueryRepositoryImpl<T, ID>(private val entity: RelationalEntityInformation<T, ID>,
                                        databaseClient: DatabaseClient,
                                        converter: R2dbcConverter,
                                        accessStrategy: ReactiveDataAccessStrategy
) : SimpleR2dbcRepository<T, ID>(entity, databaseClient, converter, accessStrategy), DynamicQueryRepository<T, ID> {

    /** Template used for the criteria-based selects and counts issued by this class. */
    private val dynamicOperations: R2dbcEntityTemplate = R2dbcEntityTemplate(databaseClient, accessStrategy)

    /**
     * Field name -> declared Java type for the entity and its superclasses. Computed once and
     * unmodifiable, so the map handed out by [metaInfo] cannot alter the filter whitelist.
     */
    private val fieldContainer: Map<String, Class<*>> = collectFields(entity.javaType)

    /** Streams all entities matching [params] (see [createQuery] for the key syntax). */
    override fun findAll(params: Map<String, Any>?): Flux<T> =
            this.findAll(createQuery(params))

    /** Pages through the entities matching [params] (see [createQuery] for the key syntax). */
    override fun findAll(params: Map<String, Any>?, pageable: Pageable): Mono<PageResult<T>> =
            this.findAll(createQuery(params), pageable)

    /** Streams all entities matching [query]. */
    override fun findAll(query: Query): Flux<T> =
            dynamicOperations.select(query, entity.javaType)

    /**
     * Fetches one page of entities matching [query] together with the total number of matches.
     *
     * Limit, offset and the pageable's sort are applied to the select only. The total is derived
     * from the fetched page when that is conclusive, otherwise a COUNT over [query] is issued.
     * Unpaged requests return every match, with the total equal to the number of rows fetched.
     */
    override fun findAll(query: Query, pageable: Pageable): Mono<PageResult<T>> =
            dynamicOperations.select(query.addPageable(pageable), entity.javaType)
                    .collectList()
                    .zipWhen({ records -> resolveTotal(query, pageable, records.size) }) { records, total ->
                        PageResult(total, records)
                    }

    /** Exposes the (read-only) field name -> type map used to validate filter keys. */
    override fun metaInfo(): Mono<Map<String, Class<*>>> = Mono.just(fieldContainer)

    /**
     * Determines the total match count for a page of [fetched] rows, avoiding a COUNT query when
     * the page itself proves the total (same rule as Spring Data's `PageableExecutionUtils`).
     */
    private fun resolveTotal(query: Query, pageable: Pageable, fetched: Int): Mono<Long> = when {
        // Unpaged Pageables throw on pageSize/offset; everything was fetched in one go.
        pageable.isUnpaged -> Mono.just(fetched.toLong())
        // A short page is the last page - except an empty page past the start, whose offset may
        // lie beyond the end of the data and therefore says nothing about the real total.
        fetched < pageable.pageSize && (fetched > 0 || pageable.offset == 0L) ->
            Mono.just(pageable.offset + fetched)
        else -> dynamicOperations.count(query, entity.javaType)
    }

    /** Applies [pageable]'s limit, offset and sort to this query; unpaged requests are left as-is. */
    private fun Query.addPageable(pageable: Pageable): Query =
            if (pageable.isUnpaged) {
                this
            } else {
                this.limit(pageable.pageSize).offset(pageable.offset).sort(pageable.sort)
            }

    /**
     * Builds a dynamic [Query] from a parameter map.
     *
     * Each key is split by the project's `getPrefix`/`getPostfix` helpers into an operator prefix
     * (with `EQ` passed as the helper's default; compared case-insensitively) and a field name.
     * Supported prefixes:
     * - `EQ`, `LT`, `ELT` (<=), `GT`, `EGT` (>=): comparisons;
     * - `LK` (`%value%`), `LLK` (`%value`), `RLK` (`value%`): LIKE patterns;
     * - `IN`: a collection, or a string split on commas;
     * - `ORDER`: sort on the field; the value is a direction name such as `asc`/`desc`.
     *
     * Keys whose field is unknown, or whose prefix is not listed above, are ignored. For
     * [LocalDateTime] fields every non-`ORDER` value is converted with [toLocalDateTime].
     *
     * @param params filter map; `null` yields [Query.empty].
     * @throws IllegalArgumentException if an `ORDER` value is not a valid sort direction.
     * @throws NumberFormatException if a date value is neither a [LocalDateTime], a [Number],
     *   nor a string holding an integral number.
     */
    private fun createQuery(params: Map<String, Any>?): Query {
        if (params == null) {
            return Query.empty()
        }
        val orders = mutableListOf<Sort.Order>()
        var criteria: Criteria = Criteria.empty()
        for ((key, value) in params) {
            val fieldName = key.getPostfix()
            val fieldType = fieldContainer[fieldName] ?: continue
            val prefix = prefixOf(key)
            if (prefix == "ORDER") {
                // fromString is case-insensitive and locale-safe; accepts any value's string form.
                orders.add(Sort.Order(Sort.Direction.fromString(value.toString()), fieldName))
                continue
            }
            val operand = if (fieldType == LocalDateTime::class.java) toLocalDateTime(value) else value
            criteria = when (prefix) {
                "EQ" -> criteria.and(fieldName).`is`(operand)
                "LT" -> criteria.and(fieldName).lessThan(operand)
                "ELT" -> criteria.and(fieldName).lessThanOrEquals(operand)
                "GT" -> criteria.and(fieldName).greaterThan(operand)
                "EGT" -> criteria.and(fieldName).greaterThanOrEquals(operand)
                "LK" -> criteria.and(fieldName).like("%${operand}%")
                "LLK" -> criteria.and(fieldName).like("%${operand}")
                "RLK" -> criteria.and(fieldName).like("${operand}%")
                "IN" -> criteria.and(fieldName).`in`(toInOperand(operand))
                else -> criteria
            }
        }
        return Query.query(criteria).sort(Sort.by(orders))
    }

    /**
     * Upper-cased operator prefix of [key]. Uses [Locale.ROOT] so that e.g. `in` does not become
     * the dotted `İN` under a Turkish default locale.
     */
    private fun prefixOf(key: String): String = key.getPrefix("EQ").toUpperCase(Locale.ROOT)

    companion object {
        /**
         * Collects the declared instance fields of [type] and all its superclasses (stopping
         * before [Object]). When a name is declared more than once, the most-derived class wins.
         */
        private fun collectFields(type: Class<*>): Map<String, Class<*>> {
            val fields = LinkedHashMap<String, Class<*>>()
            generateSequence<Class<*>>(type) { it.superclass }
                    .takeWhile { it != Any::class.java }
                    .flatMap { it.declaredFields.asSequence() }
                    .filterNot { Modifier.isStatic(it.modifiers) || it.isSynthetic }
                    .forEach { field ->
                        if (field.name !in fields) {
                            fields[field.name] = field.type
                        }
                    }
            return Collections.unmodifiableMap(fields)
        }

        /**
         * Converts a filter value for a [LocalDateTime] field. A [LocalDateTime] passes through;
         * a [Number], or the string form of any other value, is read as epoch milliseconds and
         * interpreted in the system default time zone.
         */
        private fun toLocalDateTime(value: Any): LocalDateTime {
            if (value is LocalDateTime) {
                return value
            }
            val epochMillis = (value as? Number)?.toLong() ?: value.toString().toLong()
            return LocalDateTime.ofInstant(Instant.ofEpochMilli(epochMillis), ZoneId.systemDefault())
        }

        /** `IN` operand: collections are used as-is, any other value is split on commas. */
        private fun toInOperand(value: Any): Collection<*> =
                value as? Collection<*> ?: value.toString().split(",")
    }
}

/**
 * Minimal data class holding a single nullable, mutable [LocalDateTime] property. In the visible
 * source it is only used by `main` below, apparently as a reflection experiment.
 */
data class A(
    var ldt: LocalDateTime?
)

/**
 * Ad-hoc check: for every declared field of [A], prints the field's type and whether that type
 * is exactly `LocalDateTime::class.java`.
 *
 * NOTE(review): looks like leftover debugging code in a production source file; consider moving
 * it into a test.
 */
fun main() {
    for (declared in A::class.java.declaredFields) {
        val fieldType = declared.type
        println(fieldType)
        println(fieldType == LocalDateTime::class.java)
    }
}